{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Build the Once-for-All ResNet50 network (requested with pretrained=True).\n",
    "# Later cells read its search-space attributes and sample sub-networks from it.\n",
    "from ofa.model_zoo import ofa_net\n",
    "\n",
    "ofa_network = ofa_net(\n",
    "    'ofa_resnet50',\n",
    "    pretrained=True,\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loaded checkpoint from /Users/hancai/.ofa/ofa_resnet50_acc_predictor.pth\n",
      "The accuracy predictor is ready!\n",
      "AccuracyPredictor(\n",
      "  (layers): Sequential(\n",
      "    (0): Sequential(\n",
      "      (0): Linear(in_features=82, out_features=400, bias=True)\n",
      "      (1): ReLU(inplace=True)\n",
      "    )\n",
      "    (1): Sequential(\n",
      "      (0): Linear(in_features=400, out_features=400, bias=True)\n",
      "      (1): ReLU(inplace=True)\n",
      "    )\n",
      "    (2): Sequential(\n",
      "      (0): Linear(in_features=400, out_features=400, bias=True)\n",
      "      (1): ReLU(inplace=True)\n",
      "    )\n",
      "    (3): Linear(in_features=400, out_features=1, bias=False)\n",
      "  )\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "# Accuracy predictor: encodes a sub-network configuration (plus input resolution)\n",
    "# and feeds it to a pretrained predictor loaded from a downloaded checkpoint.\n",
    "import torch\n",
    "from ofa.nas.accuracy_predictor import AccuracyPredictor, ResNetArchEncoder\n",
    "from ofa.utils import download_url\n",
    "\n",
    "# Candidate input resolutions; the search cell picks one of these per sample.\n",
    "image_size_list = [128, 144, 160, 176, 192, 224, 240, 256]\n",
    "\n",
    "# The encoder is configured from the search-space attributes of ofa_network,\n",
    "# which is created in the first cell of this notebook.\n",
    "arch_encoder = ResNetArchEncoder(\n",
    "    image_size_list=image_size_list,\n",
    "    depth_list=ofa_network.depth_list,\n",
    "    expand_list=ofa_network.expand_ratio_list,\n",
    "    width_mult_list=ofa_network.width_mult_list,\n",
    "    base_depth_list=ofa_network.BASE_DEPTH_LIST,\n",
    ")\n",
    "\n",
    "# Fetch the predictor checkpoint; model_dir is the directory passed to download_url.\n",
    "acc_predictor_checkpoint_path = download_url(\n",
    "    'https://hanlab.mit.edu/files/OnceForAll/tutorial/ofa_resnet50_acc_predictor.pth',\n",
    "    model_dir='~/.ofa/',\n",
    ")\n",
    "\n",
    "# Use the first CUDA device when one is available, otherwise stay on the CPU.\n",
    "device = 'cuda:0' if torch.cuda.is_available() else 'cpu'\n",
    "\n",
    "# The positional 400 and 3 presumably set the hidden width and the number of hidden\n",
    "# blocks; they match the 400-unit, 3-block module printed by this cell.\n",
    "acc_predictor = AccuracyPredictor(\n",
    "    arch_encoder,\n",
    "    400,\n",
    "    3,\n",
    "    checkpoint_path=acc_predictor_checkpoint_path,\n",
    "    device=device,\n",
    ")\n",
    "\n",
    "print('The accuracy predictor is ready!')\n",
    "print(acc_predictor)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "outputs": [],
   "source": [
    "# Efficiency predictor: a FLOPs model constructed from ofa_network (first cell).\n",
    "# The search cell calls its get_efficiency() and reports the value in M MACs.\n",
    "from ofa.nas.efficiency_predictor import ResNet50FLOPsModel\n",
    "\n",
    "efficiency_predictor = ResNet50FLOPsModel(ofa_network)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 \t tensor([0.8015], grad_fn=<AddBackward0>) \t 1367.1M MACs\n",
      "1 \t tensor([0.8080], grad_fn=<AddBackward0>) \t 1677.4M MACs\n",
      "2 \t tensor([0.8004], grad_fn=<AddBackward0>) \t 1374.6M MACs\n",
      "3 \t tensor([0.8227], grad_fn=<AddBackward0>) \t 4557.3M MACs\n",
      "4 \t tensor([0.8187], grad_fn=<AddBackward0>) \t 2547.1M MACs\n",
      "5 \t tensor([0.8014], grad_fn=<AddBackward0>) \t 1532.8M MACs\n",
      "6 \t tensor([0.8099], grad_fn=<AddBackward0>) \t 1460.4M MACs\n",
      "7 \t tensor([0.8224], grad_fn=<AddBackward0>) \t 2745.0M MACs\n",
      "8 \t tensor([0.8083], grad_fn=<AddBackward0>) \t 1672.9M MACs\n",
      "9 \t tensor([0.7812], grad_fn=<AddBackward0>) \t 959.9M MACs\n"
     ]
    }
   ],
   "source": [
    "# search\n",
    "import random\n",
    "\n",
    "# Random-search demo: sample 10 sub-networks, pair each with a randomly chosen\n",
    "# input resolution, and print predicted accuracy next to predicted cost.\n",
    "for i in range(10):\n",
    "    subnet_config = ofa_network.sample_active_subnet()\n",
    "    image_size = random.choice(image_size_list)\n",
    "    subnet_config.update({'image_size': image_size})\n",
    "\n",
    "    # Pure inference: without no_grad() the predictor returns a tensor carrying a\n",
    "    # grad_fn, i.e. an autograd graph is built for every query.\n",
    "    # (torch is imported in the accuracy-predictor cell, which must run first.)\n",
    "    with torch.no_grad():\n",
    "        predicted_acc = acc_predictor.predict_acc([subnet_config])\n",
    "    predicted_efficiency = efficiency_predictor.get_efficiency(subnet_config)\n",
    "\n",
    "    print(i, '\\t', predicted_acc, '\\t', '%.1fM MACs' % predicted_efficiency)\n"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
